import numpy as np
from typing import Tuple, Optional
from gps.kde import kde_gps
from utils.kernel import ColumnWiseGaussianKernel, AbsKernel, BinaryKernel, GaussianKernel, e_kernel,gaussian_kernel


def get_kernel_func(data_name: str) -> Tuple[AbsKernel, AbsKernel, AbsKernel, AbsKernel]:
    """Return four freshly constructed kernel objects for the named data set.

    The code that consumes this tuple assigns the positions, in order, to the
    treatment, treatment-proxy, outcome-proxy and backdoor kernels.

    Args:
        data_name: ``"mw"`` selects a ``BinaryKernel`` for the first slot and
            ``GaussianKernel`` for the other three; ``"hypo"`` selects
            ``GaussianKernel`` for all four; any other value selects
            ``ColumnWiseGaussianKernel`` for all four.

    Returns:
        A 4-tuple of new, independent kernel instances.
    """
    # Choose the class for the first slot and the class shared by the rest,
    # then instantiate once at the end (first slot first, as before).
    if data_name == "mw":
        first_cls = BinaryKernel
        rest_cls = GaussianKernel
    elif data_name == 'hypo':
        first_cls = rest_cls = GaussianKernel
    else:
        first_cls = rest_cls = ColumnWiseGaussianKernel

    return first_cls(), rest_cls(), rest_cls(), rest_cls()


class PMMRModel:
    """Kernel estimator holding two coefficient vectors, ``alpha_h`` and ``alpha_q``.

    * ``fit_h`` solves a regularised linear system built from product Gram
      matrices of (treatment A, outcome proxy W, backdoor X) and
      (treatment A, treatment proxy Z, backdoor X) with the outcome ``Y`` as
      target, and stores the solution in ``alpha_h``.
    * ``fit_q`` solves the mirrored system with the roles of W and Z swapped
      and the output of ``kde_gps(A, W, X)`` as target, storing ``alpha_q``.
    * ``predict_h*``, ``predict_q*`` and ``drtest*`` evaluate the fitted
      functions on new samples and average them, once per requested
      treatment value.

    NOTE: ``fit_h`` and ``fit_q`` share the same four kernel objects and both
    overwrite ``train_treatment`` / ``train_backdoor``.  Whichever was called
    last determines the kernel state, so methods that use both coefficient
    vectors (``drtest``/``drtest_mul``) expect both fits to have used the same
    training sample.
    """

    treatment_kernel_func: "AbsKernel"
    treatment_proxy_kernel_func: "AbsKernel"
    outcome_proxy_kernel_func: "AbsKernel"
    backdoor_kernel_func: "AbsKernel"

    # Set by fit_h / fit_q respectively.
    alpha_h: np.ndarray
    alpha_q: np.ndarray
    # Column means of the backdoor Gram matrix; stay None without a backdoor.
    x_mean_vec_h: Optional[np.ndarray]
    x_mean_vec_q: Optional[np.ndarray]
    # Training inputs kept for building test-time kernel matrices.
    train_treatment: np.ndarray
    train_outcome_proxy: np.ndarray
    train_treatment_proxy: np.ndarray
    train_backdoor: np.ndarray

    def __init__(self, lamh1=0.2, lamh2=0.2, lamq1=0.2, lamq2=0.2, scale=0.25):
        """Store the hyper-parameters and construct the kernel objects.

        Args:
            lamh1: weight of the ``n * L`` term in the ``fit_h`` system.
            lamh2: weight of the ``n * I`` term in the ``fit_h`` system.
            lamq1: weight of the ``n * L`` term in the ``fit_q`` system.
            lamq2: weight of the ``n * I`` term in the ``fit_q`` system.
            scale: forwarded as ``scale=`` to every kernel's ``fit`` call.
        """
        self.lamh1 = lamh1
        self.lamh2 = lamh2
        self.lamq1 = lamq1
        self.lamq2 = lamq2
        self.scale = scale
        self.x_mean_vec_h = None
        self.x_mean_vec_q = None

        self._pretrain()

    def _pretrain(self,):
        """Create the four kernel objects using the ``'hypo'`` configuration."""
        (self.treatment_kernel_func,
         self.treatment_proxy_kernel_func,
         self.outcome_proxy_kernel_func,
         self.backdoor_kernel_func) = get_kernel_func('hypo')

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _flatten_2d(values):
        """Return ``values`` raveled when it is 2-D, otherwise unchanged."""
        return values.ravel() if values.ndim == 2 else values

    @staticmethod
    def _solve_alpha(feature_mat, weight_mat, target, lam1, lam2, n):
        """Solve ``(L W L + lam1*n*L + lam2*n*I) alpha = L W target``.

        ``L`` is ``feature_mat`` and ``W`` is ``weight_mat``.  ``L @ W`` is
        computed once and reused on both sides of the system.
        """
        lw = feature_mat @ weight_mat
        lhs = lw @ feature_mat + lam1 * n * feature_mat + lam2 * n * np.eye(n)
        return np.linalg.solve(lhs, lw @ target)

    def _fit_gram_matrices(self, A, W, Z, X, n):
        """Fit every kernel on its training input and build the Gram matrices.

        Each kernel's ``fit`` receives ``scale=self.scale`` (presumably a
        bandwidth rule such as the median heuristic -- confirm in
        ``utils.kernel``).

        Returns:
            ``(K_A, K_Z, K_W, K_X)``.  ``K_X`` is an ``n x n`` matrix of ones
            when ``X`` is None, so it drops out of element-wise products.
        """
        self.treatment_proxy_kernel_func.fit(Z, scale=self.scale)
        self.treatment_kernel_func.fit(A, scale=self.scale)
        self.outcome_proxy_kernel_func.fit(W, scale=self.scale)
        if X is not None:
            self.backdoor_kernel_func.fit(X, scale=self.scale)

        k_a = self.treatment_kernel_func.cal_kernel_mat(A, A)
        k_z = self.treatment_proxy_kernel_func.cal_kernel_mat(Z, Z)
        k_w = self.outcome_proxy_kernel_func.cal_kernel_mat(W, W)
        if X is None:
            k_x = np.ones((n, n))
        else:
            k_x = self.backdoor_kernel_func.cal_kernel_mat(X, X)
        return k_a, k_z, k_w, k_x

    def _treatment_kernel_at(self, a, n_rows, n_cols):
        """Kernel between the training treatments and ``a`` repeated ``n_rows`` times."""
        a_full = np.full((n_rows, n_cols), a)
        return self.treatment_kernel_func.cal_kernel_mat(self.train_treatment, a_full)

    def _backdoor_test_kernel(self, X):
        """Kernel between the training backdoor and ``X``; None when ``X`` is None."""
        if X is None:
            return None
        return self.backdoor_kernel_func.cal_kernel_mat(self.train_backdoor, X)

    @staticmethod
    def _evaluate(alpha, treatment_kernel, proxy_kernel, backdoor_kernel):
        """Return ``(alpha.T @ K).T`` for the element-wise product kernel ``K``."""
        test_kernel = treatment_kernel * proxy_kernel
        if backdoor_kernel is not None:
            test_kernel = test_kernel * backdoor_kernel
        return (alpha.T @ test_kernel).T

    @staticmethod
    def _split_point(a):
        """Split a treatment point into its two coordinates.

        A scalar (or length-1) point is used for both coordinates, which keeps
        the previous scalar behaviour; a length-2 point is split per column.
        """
        first, second = np.broadcast_to(np.asarray(a).ravel(), (2,))
        return first, second

    def _mean_h(self, pointA, W, X, treatment_dim):
        """Average the fitted ``h`` over the rows of ``W``/``X`` for each point."""
        # These two kernels do not depend on the treatment value: build once.
        k_w = self.outcome_proxy_kernel_func.cal_kernel_mat(self.train_outcome_proxy, W)
        k_x = self._backdoor_test_kernel(X)

        means = []
        for a in pointA:
            k_a = self._treatment_kernel_at(a, len(W), treatment_dim)
            means.append(np.mean(self._evaluate(self.alpha_h, k_a, k_w, k_x)))
        return means

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def fit_h(self, A, W, Z, Y, X):
        """Fit ``alpha_h`` from treatment, proxies, outcome and optional backdoor.

        Args:
            A: training treatments.
            W: training outcome proxies.
            Z: training treatment proxies.
            Y: training outcomes (right-hand-side target).
            X: training backdoor variables, or None to omit them.

        Side effects:
            Refits all kernels; sets ``alpha_h``, ``train_treatment``,
            ``train_outcome_proxy`` and, when ``X`` is given,
            ``x_mean_vec_h`` and ``train_backdoor``.
        """
        n = len(Y)
        k_a, k_z, k_w, k_x = self._fit_gram_matrices(A, W, Z, X, n)
        if X is not None:
            self.x_mean_vec_h = np.mean(k_x, axis=0)[:, np.newaxis]

        weight_mat = k_a * k_z * k_x    # (A, Z, X)
        feature_mat = k_a * k_w * k_x   # (A, W, X)
        self.alpha_h = self._solve_alpha(feature_mat, weight_mat, Y,
                                         self.lamh1, self.lamh2, n)

        self.train_treatment = A
        self.train_outcome_proxy = W
        if X is not None:
            self.train_backdoor = X

    def fit_q(self, A, W, Z, X):
        """Fit ``alpha_q`` using ``kde_gps(A, W, X)`` as the target.

        Args:
            A: training treatments.
            W: training outcome proxies.
            Z: training treatment proxies.
            X: training backdoor variables, or None to omit them.

        Side effects:
            Refits all kernels; sets ``alpha_q``, ``train_treatment``,
            ``train_treatment_proxy`` and, when ``X`` is given,
            ``x_mean_vec_q`` and ``train_backdoor``.
        """
        # Target of the linear system, reshaped to a column vector.
        target = kde_gps(A, W, X)[:, np.newaxis]

        n = len(Z)
        k_a, k_z, k_w, k_x = self._fit_gram_matrices(A, W, Z, X, n)
        if X is not None:
            self.x_mean_vec_q = np.mean(k_x, axis=0)[:, np.newaxis]

        feature_mat = k_a * k_z * k_x   # (A, Z, X)
        weight_mat = k_a * k_w * k_x    # (A, W, X)
        self.alpha_q = self._solve_alpha(feature_mat, weight_mat, target,
                                         self.lamq1, self.lamq2, n)

        self.train_treatment = A
        self.train_treatment_proxy = Z
        if X is not None:
            self.train_backdoor = X

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def predict_h(self, pointA, W, X=None) -> list:
        """Mean of the fitted ``h`` at each scalar treatment value in ``pointA``.

        Args:
            pointA: iterable of treatment values.
            W: outcome proxies to average over.
            X: matching backdoor variables, or None.

        Returns:
            A list with one mean per entry of ``pointA``.
        """
        return self._mean_h(pointA, W, X, treatment_dim=1)

    def predict_h_mul(self, pointA, W, X=None) -> list:
        """Same as ``predict_h`` for a two-column treatment.

        Each entry of ``pointA`` may be a scalar (used for both columns) or a
        length-2 point.
        """
        return self._mean_h(pointA, W, X, treatment_dim=2)

    def predict_q(self, pointA, A, Z, Y, X=None) -> list:
        """Mean of ``gaussian_kernel(A - a) * Y * q`` for each ``a`` in ``pointA``.

        Args:
            pointA: iterable of scalar treatment values.
            A: observed treatments (2-D input is raveled).
            Z: treatment proxies.
            Y: observed outcomes (2-D input is raveled).
            X: backdoor variables, or None.

        Returns:
            A list with one estimate per entry of ``pointA``.
        """
        Y = self._flatten_2d(Y)
        A = self._flatten_2d(A)
        bandwidth = 1.5 * np.std(A) * (len(A) ** -0.2)

        # Independent of the treatment value: build once.
        k_z = self.treatment_proxy_kernel_func.cal_kernel_mat(self.train_treatment_proxy, Z)
        k_x = self._backdoor_test_kernel(X)

        estimates = []
        for a in pointA:
            k_a = self._treatment_kernel_at(a, len(A), 1)
            q = self._flatten_2d(self._evaluate(self.alpha_q, k_a, k_z, k_x))
            estimates.append(np.mean(gaussian_kernel(A - a, bandwidth) * Y * q))
        return estimates

    def predict_q_mul(self, pointA, A, Z, Y, X=None) -> list:
        """Two-column-treatment variant of ``predict_q`` using ``e_kernel``.

        Each entry of ``pointA`` may be a scalar (used for both treatment
        columns) or a length-2 point; the smoothing weight is the product of
        one ``e_kernel`` per treatment column, each centred on its own
        coordinate.
        """
        Y = self._flatten_2d(Y)
        bandwidth0 = 1.5 * np.std(A[:, 0]) * (len(Z) ** -0.2)
        bandwidth1 = 1.5 * np.std(A[:, 1]) * (len(Z) ** -0.2)

        k_z = self.treatment_proxy_kernel_func.cal_kernel_mat(self.train_treatment_proxy, Z)
        k_x = self._backdoor_test_kernel(X)

        estimates = []
        for a in pointA:
            k_a = self._treatment_kernel_at(a, len(A), 2)
            q = self._flatten_2d(self._evaluate(self.alpha_q, k_a, k_z, k_x))
            a0, a1 = self._split_point(a)
            weight = e_kernel(A[:, 0] - a0, bandwidth0) * e_kernel(A[:, 1] - a1, bandwidth1)
            estimates.append(np.mean(weight * Y * q))
        return estimates

    def drtest(self, pointA, A, Z, W, Y, X=None, bandwidth=0.5) -> list:
        """Combine both fits: mean of ``(Y - h) * q * K(A - a) + h`` per point.

        Args:
            pointA: iterable of scalar treatment values.
            A: observed treatments (2-D input is raveled).
            Z: treatment proxies.
            W: outcome proxies.
            Y: observed outcomes (2-D input is raveled).
            X: backdoor variables, or None.
            bandwidth: bandwidth passed to ``gaussian_kernel``; defaults to
                the previously hard-coded value 0.5.

        Returns:
            A list with one estimate per entry of ``pointA``.
        """
        Y = self._flatten_2d(Y)
        A = self._flatten_2d(A)

        # Independent of the treatment value: build once.
        k_w = self.outcome_proxy_kernel_func.cal_kernel_mat(self.train_outcome_proxy, W)
        k_z = self.treatment_proxy_kernel_func.cal_kernel_mat(self.train_treatment_proxy, Z)
        k_x = self._backdoor_test_kernel(X)

        estimates = []
        for a in pointA:
            k_a = self._treatment_kernel_at(a, len(A), 1)
            pred_h = self._flatten_2d(self._evaluate(self.alpha_h, k_a, k_w, k_x))
            pred_q = self._flatten_2d(self._evaluate(self.alpha_q, k_a, k_z, k_x))
            estimates.append(
                np.mean((Y - pred_h) * pred_q * gaussian_kernel(A - a, bandwidth) + pred_h)
            )
        return estimates

    def drtest_mul(self, pointA, A, Z, W, Y, X=None) -> list:
        """Two-column-treatment variant of ``drtest`` using ``e_kernel``.

        Differences from ``drtest``: bandwidths follow the
        ``1.5 * std * n**-0.2`` rule per treatment column, the ``q`` values
        are floored at 0.01, and each entry of ``pointA`` may be a scalar
        (used for both columns) or a length-2 point.
        """
        Y = self._flatten_2d(Y)
        bandwidth0 = 1.5 * np.std(A[:, 0]) * (len(Z) ** -0.2)
        bandwidth1 = 1.5 * np.std(A[:, 1]) * (len(Z) ** -0.2)

        k_w = self.outcome_proxy_kernel_func.cal_kernel_mat(self.train_outcome_proxy, W)
        k_z = self.treatment_proxy_kernel_func.cal_kernel_mat(self.train_treatment_proxy, Z)
        k_x = self._backdoor_test_kernel(X)

        estimates = []
        for a in pointA:
            k_a = self._treatment_kernel_at(a, len(A), 2)
            pred_h = self._flatten_2d(self._evaluate(self.alpha_h, k_a, k_w, k_x))
            pred_q = self._flatten_2d(self._evaluate(self.alpha_q, k_a, k_z, k_x))
            # Floor q at 0.01, as the original in-place clipping did.
            pred_q = np.maximum(pred_q, 0.01)

            a0, a1 = self._split_point(a)
            estimates.append(
                np.mean(
                    (Y - pred_h) * pred_q
                    * e_kernel(A[:, 0] - a0, bandwidth0)
                    * e_kernel(A[:, 1] - a1, bandwidth1)
                    + pred_h
                )
            )
        return estimates
            

